import abc
import functools
import random
from datetime import datetime
import math
import sys
import logging
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Generator,
    Iterator,
    NewType,
    Tuple,
    Optional,
    Sequence,
    Type,
    List,
    Union,
    TypeVar,
)
from functools import partial, wraps
from concurrent.futures import ThreadPoolExecutor
import threading
from abc import abstractmethod
from uuid import UUID
import decimal
import contextvars

import attrs
from typing_extensions import Self

from data_diff.abcs.compiler import AbstractCompiler, Compilable
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
from data_diff.schema import RawColumnInfo
from data_diff.utils import ArithString, ArithUUID, is_uuid, join_iter, safezip
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
from data_diff.queries.ast_classes import (
    Alias,
    BinOp,
    CaseWhen,
    Cast,
    Column,
    Commit,
    Concat,
    ConstantTable,
    Count,
    CreateTable,
    Cte,
    CurrentTimestamp,
    DropTable,
    Func,
    GroupBy,
    ITable,
    In,
    InsertToTable,
    IsDistinctFrom,
    Join,
    Param,
    Random,
    Root,
    TableAlias,
    TableOp,
    TablePath,
    TruncateTable,
    UnaryOp,
    WhenThen,
    _ResolveColumn,
)
from data_diff.abcs.database_types import (
    Array,
    ColType_UUID,
    FractionalType,
    Struct,
    ColType,
    Integer,
    Decimal,
    Float,
    Native_UUID,
    String_UUID,
    String_Alphanum,
    String_VaryingAlphanum,
    TemporalType,
    UnknownColType,
    TimestampTZ,
    Text,
    DbTime,
    DbPath,
    Boolean,
    JSON,
)

# Module-wide logger; ThreadLocalInterpreter logs every SQL statement it runs through it at DEBUG level.
logger = logging.getLogger("database")
# Query parameters of the current compilation: set by BaseDialect.compile(), read by BaseDialect.render_param().
cv_params = contextvars.ContextVar("params")


class CompileError(Exception):
    """Raised when a query AST cannot be compiled into SQL."""


@attrs.define(frozen=True)
class Compiler(AbstractCompiler):
    """
    Compiler holds the context of a single compilation.

    There can be multiple compilations per app run.
    There can be multiple compilers in one compilation (with varying contexts).
    """

    # Database is needed to normalize tables. Dialect is needed for recursive compilations.
    # In theory, it is a many-to-many relation: e.g. a generic ODBC driver with multiple dialects.
    # In practice, we currently bind the dialects to the specific database classes.
    database: "Database"

    in_select: bool = False  # Compilation runtime flag
    in_join: bool = False  # Compilation runtime flag

    _table_context: List = attrs.field(factory=list)  # List[ITable]
    _subqueries: Dict[str, Any] = attrs.field(factory=dict)  # XXX not thread-safe
    root: bool = True

    # A one-element list, so the counter can be bumped in place although the instance is frozen.
    _counter: List = attrs.field(factory=lambda: [0])

    @property
    def dialect(self) -> "BaseDialect":
        """The dialect of the database this compiler is bound to."""
        return self.database.dialect

    # TODO: DEPRECATED: Remove once the dialect is used directly in all places.
    def compile(self, elem, params=None) -> str:
        """Compile `elem` by delegating to the dialect, passing this compiler as the context."""
        dialect = self.dialect
        return dialect.compile(self, elem, params)

    def new_unique_name(self, prefix="tmp") -> str:
        """Bump the counter and return `prefix` followed by the new counter value."""
        counter = self._counter
        counter[0] += 1
        return f"{prefix}{counter[0]}"

    def new_unique_table_name(self, prefix="tmp") -> DbPath:
        """Return a table path built from `prefix`, the bumped counter and a random hex suffix."""
        counter = self._counter
        counter[0] += 1
        suffix = format(random.randrange(2**32), "x")
        return self.dialect.parse_table_name(f"{prefix}{counter[0]}_{suffix}")

    def add_table_context(self, *tables: Sequence, **kw) -> Self:
        """Return an evolved copy whose table context is extended by `tables`; `kw` overrides other fields."""
        extended = [*self._table_context, *tables]
        return attrs.evolve(self, table_context=extended, **kw)


def parse_table_name(t):
    """Split a dot-separated table name into a tuple of its parts."""
    parts = t.split(".")
    return tuple(parts)


def import_helper(package: Optional[str] = None, text=""):
    """Build a decorator that turns a missing optional dependency into a more helpful error.

    The decorated function is called without arguments. If it raises
    ``ModuleNotFoundError``, a new ``ModuleNotFoundError`` is raised whose message is
    the original message followed by `text` and, when `package` is given, a hint to
    install the matching ``data_diff`` extra.

    Parameters:
        package: name of the ``data_diff`` extra to suggest; no hint is added when falsy.
        text: free-form text placed before the install hint.

    Raises (from the decorated function):
        ModuleNotFoundError: with the extended message; the original error is kept as ``__cause__``.
    """

    def dec(f):
        @wraps(f)
        def _inner():
            try:
                return f()
            except ModuleNotFoundError as e:
                hint = text
                if package:
                    hint += f"Please complete setup by running: pip install 'data_diff[{package}]'."
                # Chain explicitly, so the traceback presents the original failure as the direct cause.
                raise ModuleNotFoundError(f"{e}\n\n{hint}\n") from e

        return _inner

    return dec


class ConnectError(Exception):
    """Exception type for connection errors."""

    # NOTE(review): presumably raised when connecting to a database fails - confirm at the raise sites.


class QueryError(Exception):
    """Exception type for query errors."""

    # NOTE(review): presumably raised when running a query fails - confirm at the raise sites.


def _one(seq):
    """Return the only item of `seq`; the unpacking raises ``ValueError`` unless there is exactly one."""
    [only] = seq
    return only


@attrs.define(frozen=False)
class ThreadLocalInterpreter:
    """An interpreter used to execute a sequence of queries within the same thread and cursor.

    Useful for cursor-sensitive operations, such as creating a temporary table.
    """

    compiler: Compiler
    gen: Generator

    def apply_queries(self, callback: Callable[[str], Any]) -> None:
        """Drive the generator: compile each yielded query, run it, and feed the outcome back.

        Each query yielded by `gen` is compiled and passed to `callback`; the result is
        sent back into the generator. A query that compiles to ``SKIP`` is not executed
        and ``SKIP`` is sent back instead. An exception raised by `callback` is thrown
        into the generator, which may handle it and yield the next query. The loop ends
        when the generator is exhausted.
        """
        try:
            q: Expr = next(self.gen)
        except StopIteration:
            # The generator yielded no queries at all: nothing to run.
            return

        while True:
            sql = self.compiler.database.dialect.compile(self.compiler, q)
            logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql)
            try:
                try:
                    res = callback(sql) if sql is not SKIP else SKIP
                except Exception as e:
                    # Single-argument form: throw(type, value) is deprecated since Python 3.12.
                    q = self.gen.throw(e)
                else:
                    q = self.gen.send(res)
            except StopIteration:
                break


def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocalInterpreter]) -> list:
    """Run `sql_code` through `callback`.

    Plain SQL is passed to `callback` directly and its result is returned.
    A ThreadLocalInterpreter runs all of its queries through `callback` instead,
    and the result of its `apply_queries()` is returned.
    """
    if not isinstance(sql_code, ThreadLocalInterpreter):
        return callback(sql_code)
    return sql_code.apply_queries(callback)


@attrs.define(frozen=False)
class BaseDialect(abc.ABC):
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
    SUPPORTS_INDEXES: ClassVar[bool] = False
    PREVENT_OVERFLOW_WHEN_CONCAT: ClassVar[bool] = False
    TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
    DEFAULT_NUMERIC_PRECISION: ClassVar[int] = 0  # effective precision when type is just "NUMERIC"

    PLACEHOLDER_TABLE = None  # Used for Oracle

    # Some databases do not support long strings, so concatenation might lead to a type overflow.

    _prevent_overflow_when_concat: bool = False

    def enable_preventing_type_overflow(self) -> None:
        """Switch on the mode in which `render_concat()` hashes each item before concatenating."""
        self._prevent_overflow_when_concat = True
        logger.info("Preventing type overflow when concatenation is enabled")

    def parse_table_name(self, name: str) -> DbPath:
        """Parse the given table name into a DbPath, using the module-level `parse_table_name()`."""
        path = parse_table_name(name)
        return path

    def compile(self, compiler: Compiler, elem, params=None) -> str:
        """Compile `elem` into SQL text.

        When `params` is given (and truthy), it is stored in `cv_params` for `render_param()`.
        At the root of a compilation, a compilable that is not a `Root` node is wrapped in a
        SELECT, and the subqueries collected while compiling are emitted as a leading
        WITH clause (and then cleared from the compiler).
        """
        if params:
            cv_params.set(params)

        is_root = compiler.root
        if is_root and isinstance(elem, Compilable) and not isinstance(elem, Root):
            from data_diff.queries.ast_classes import Select

            elem = Select(columns=[elem])

        sql = self._compile(compiler, elem)
        if not (is_root and compiler._subqueries):
            return sql

        ctes = ", ".join(f"\n  {name} AS ({body})" for name, body in compiler._subqueries.items())
        compiler._subqueries.clear()
        return f"WITH {ctes}\n{sql}"

    def _compile(self, compiler: Compiler, elem) -> str:
        """Render a single element into SQL text, without the root-level handling of `compile()`.

        ``None`` becomes ``NULL``. Compilable AST nodes and column types are rendered with a
        non-root copy of `compiler`. Python values are rendered inline; note that strings are
        wrapped in single quotes as-is, without any escaping.

        Raises:
            AssertionError: if `elem` is of a type that cannot be rendered.
        """
        if elem is None:
            return "NULL"
        elif isinstance(elem, Compilable):
            return self.render_compilable(attrs.evolve(compiler, root=False), elem)
        elif isinstance(elem, ColType):
            return self.render_coltype(attrs.evolve(compiler, root=False), elem)
        elif isinstance(elem, str):
            return f"'{elem}'"
        elif isinstance(elem, (int, float)):
            return str(elem)
        elif isinstance(elem, datetime):
            return self.timestamp_value(elem)
        elif isinstance(elem, bytes):
            return f"b'{elem.decode()}'"
        elif isinstance(elem, ArithUUID):
            s = f"'{elem.uuid}'"
            return s.upper() if elem.uppercase else s.lower() if elem.lowercase else s
        elif isinstance(elem, ArithString):
            return f"'{elem}'"
        # Raise explicitly: an `assert` statement is stripped under `python -O`,
        # which would make unsupported values silently compile to None.
        raise AssertionError(elem)

    def render_compilable(self, c: Compiler, elem: Compilable) -> str:
        """Dispatch a compilable AST node to the matching ``render_*`` method.

        Known node types are dispatched explicitly; any other node is dispatched by name
        to ``render_<lowercased class name>``, if this dialect has such a method.

        Raises:
            RuntimeError: if no render method exists for the node's class.
        """
        # All ifs are only for better code navigation, IDE usage detection, and type checking.
        # The last catch-all would render them anyway - it is a typical "visitor" pattern.
        if isinstance(elem, Column):
            return self.render_column(c, elem)
        elif isinstance(elem, Cte):
            return self.render_cte(c, elem)
        elif isinstance(elem, Commit):
            return self.render_commit(c, elem)
        elif isinstance(elem, Param):
            return self.render_param(c, elem)
        elif isinstance(elem, NormalizeAsString):
            return self.render_normalizeasstring(c, elem)
        elif isinstance(elem, ApplyFuncAndNormalizeAsString):
            return self.render_applyfuncandnormalizeasstring(c, elem)
        elif isinstance(elem, Checksum):
            return self.render_checksum(c, elem)
        elif isinstance(elem, Concat):
            return self.render_concat(c, elem)
        elif isinstance(elem, Func):
            return self.render_func(c, elem)
        elif isinstance(elem, WhenThen):
            return self.render_whenthen(c, elem)
        elif isinstance(elem, CaseWhen):
            return self.render_casewhen(c, elem)
        elif isinstance(elem, IsDistinctFrom):
            return self.render_isdistinctfrom(c, elem)
        elif isinstance(elem, UnaryOp):
            return self.render_unaryop(c, elem)
        elif isinstance(elem, BinOp):
            return self.render_binop(c, elem)
        elif isinstance(elem, TablePath):
            return self.render_tablepath(c, elem)
        elif isinstance(elem, TableAlias):
            return self.render_tablealias(c, elem)
        elif isinstance(elem, TableOp):
            return self.render_tableop(c, elem)
        elif isinstance(elem, Select):
            return self.render_select(c, elem)
        elif isinstance(elem, Join):
            return self.render_join(c, elem)
        elif isinstance(elem, GroupBy):
            return self.render_groupby(c, elem)
        elif isinstance(elem, Count):
            return self.render_count(c, elem)
        elif isinstance(elem, Alias):
            return self.render_alias(c, elem)
        elif isinstance(elem, In):
            return self.render_in(c, elem)
        elif isinstance(elem, Cast):
            return self.render_cast(c, elem)
        elif isinstance(elem, Random):
            return self.render_random(c, elem)
        elif isinstance(elem, Explain):
            return self.render_explain(c, elem)
        elif isinstance(elem, CurrentTimestamp):
            return self.render_currenttimestamp(c, elem)
        elif isinstance(elem, CreateTable):
            return self.render_createtable(c, elem)
        elif isinstance(elem, DropTable):
            return self.render_droptable(c, elem)
        elif isinstance(elem, TruncateTable):
            return self.render_truncatetable(c, elem)
        elif isinstance(elem, InsertToTable):
            return self.render_inserttotable(c, elem)
        elif isinstance(elem, Code):
            return self.render_code(c, elem)
        elif isinstance(elem, _ResolveColumn):
            return self.render__resolvecolumn(c, elem)

        # Catch-all: look the render method up by the node's class name.
        method_name = f"render_{elem.__class__.__name__.lower()}"
        method = getattr(self, method_name, None)
        if method is not None:
            return method(c, elem)
        else:
            raise RuntimeError(f"Cannot render AST of type {elem.__class__}")

    def render_coltype(self, c: Compiler, elem: ColType) -> str:
        """Render a column type through `type_repr()`; the compiler is not used."""
        type_sql = self.type_repr(elem)
        return type_sql

    def render_column(self, c: Compiler, elem: Column) -> str:
        """Render a column reference, qualified by a table alias when several tables are in context.

        With at most one table in context, or when no alias in context refers to the
        column's source table, only the quoted column name is returned.

        Raises:
            CompileError: if more than one alias in context refers to the column's source table.
        """
        if not c._table_context or len(c._table_context) <= 1:
            return self.quote(elem.name)

        matching = [
            t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is elem.source_table
        ]
        if not matching:
            return self.quote(elem.name)
        if len(matching) > 1:
            raise CompileError(f"Too many aliases for column {elem.name}")

        alias = matching[0]
        return f"{self.quote(alias.name)}.{self.quote(elem.name)}"

    def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
        """Compile a CTE body, register it as a subquery of the parent compiler, and return its name.

        The body is compiled with an empty table context and outside of a SELECT.
        An unnamed CTE gets a unique name from the parent compiler.
        """
        inner_c: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False)
        body = self.compile(inner_c, elem.source_table)

        name = elem.name or parent_c.new_unique_name()
        if elem.params:
            key = f"{name}({', '.join(elem.params)})"
        else:
            key = name
        parent_c._subqueries[key] = body

        return name

    def render_commit(self, c: Compiler, elem: Commit) -> str:
        """Return ``COMMIT``, or the ``SKIP`` marker when the database is in autocommit mode."""
        if c.database.is_autocommit:
            return SKIP
        return "COMMIT"

    def render_param(self, c: Compiler, elem: Param) -> str:
        """Render the value bound to the parameter's name in the current `cv_params` mapping."""
        value = cv_params.get()[elem.name]
        return self._compile(c, value)

    def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str:
        """Compile the expression and normalize it by its explicit type, or else by the expression's own type."""
        compiled = self.compile(c, elem.expr)
        coltype = elem.expr_type or elem.expr.type
        return self.normalize_value_by_type(compiled, coltype)

    def render_applyfuncandnormalizeasstring(self, c: Compiler, elem: ApplyFuncAndNormalizeAsString) -> str:
        """Apply the optional function and the string normalization, in a type-dependent order.

        Native UUIDs are normalized first and the function is applied afterwards
        (needed because min/max(uuid) fails in postgresql). For every other type
        the function is applied first, and the result is normalized.
        """
        expr = elem.expr
        expr_type = expr.type
        normalize_first = isinstance(expr_type, Native_UUID)

        if normalize_first:
            expr = NormalizeAsString(expr, expr_type)
        if elem.apply_func is not None:
            expr = elem.apply_func(expr)  # Apply template using Python's string formatting
        if not normalize_first:
            expr = NormalizeAsString(expr, expr_type)

        return self.compile(c, expr)

    def render_checksum(self, c: Compiler, elem: Checksum) -> str:
        """Render the sum of the integer md5 of the checksummed expressions.

        Several expressions are coalesced to '<null>' and concatenated with "|" first.
        """
        if len(elem.exprs) > 1:
            coalesced = [Code(f"coalesce({self.compile(c, e)}, '<null>')") for e in elem.exprs]
            target = Concat(coalesced, "|")
        else:
            # No need to coalesce - safe to assume that key cannot be null
            (target,) = elem.exprs
        compiled = self.compile(c, target)
        return f"sum({self.md5_as_int(compiled)})"

    def render_concat(self, c: Compiler, elem: Concat) -> str:
        """Render the concatenation of the expressions as strings, with the optional separator.

        In overflow-prevention mode each item is replaced by its md5 hex digest; otherwise
        each item is coalesced to '<null>'. A single item is returned without concat().
        """
        items = []
        for expr in elem.exprs:
            as_string = self.to_string(self.compile(c, expr))
            if self._prevent_overflow_when_concat:
                items.append(f"{self.compile(c, Code(self.md5_as_hex(as_string)))}")
            else:
                # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
                items.append(f"coalesce({self.compile(c, Code(as_string))}, '<null>')")

        assert items
        if len(items) == 1:
            return items[0]

        if elem.sep:
            items = list(join_iter(f"'{elem.sep}'", items))
        return self.concat(items)

    def render_alias(self, c: Compiler, elem: Alias) -> str:
        """Render ``<expr> AS <quoted name>``."""
        compiled = self.compile(c, elem.expr)
        return f"{compiled} AS {self.quote(elem.name)}"

    def render_count(self, c: Compiler, elem: Count) -> str:
        """Render ``count(...)`` over the expression (or ``*``), optionally with ``distinct``."""
        target = self.compile(c, elem.expr) if elem.expr else "*"
        modifier = "distinct " if elem.distinct else ""
        return f"count({modifier}{target})"

    def render_code(self, c: Compiler, elem: Code) -> str:
        """Return the raw code; when it has args, they are compiled and substituted via ``str.format``."""
        if elem.args:
            compiled_args = {key: self.compile(c, value) for key, value in elem.args.items()}
            return elem.code.format(**compiled_args)
        return elem.code

    def render_func(self, c: Compiler, elem: Func) -> str:
        """Render a function call: the name followed by the compiled, comma-separated arguments."""
        compiled_args = [self.compile(c, arg) for arg in elem.args]
        return f"{elem.name}({', '.join(compiled_args)})"

    def render_whenthen(self, c: Compiler, elem: WhenThen) -> str:
        """Render a single ``WHEN <condition> THEN <result>`` clause."""
        condition = self.compile(c, elem.when)
        result = self.compile(c, elem.then)
        return f"WHEN {condition} THEN {result}"

    def render_casewhen(self, c: Compiler, elem: CaseWhen) -> str:
        """Render ``CASE <when-then clauses> [ELSE <expr>] END``; at least one case is required."""
        assert elem.cases
        clauses = [self.compile(c, case) for case in elem.cases]
        if elem.else_expr is not None:
            clauses.append("ELSE " + self.compile(c, elem.else_expr))
        return "CASE " + " ".join(clauses) + " END"

    def render_isdistinctfrom(self, c: Compiler, elem: IsDistinctFrom) -> str:
        """Compile both operands, make them comparable by their types, and render the comparison."""
        left, right = (self.to_comparable(self.compile(c, side), side.type) for side in (elem.a, elem.b))
        return self.is_distinct_from(left, right)

    def render_unaryop(self, c: Compiler, elem: UnaryOp) -> str:
        """Render the operator directly followed by its operand, in parentheses."""
        operand = self.compile(c, elem.expr)
        return f"({elem.op}{operand})"

    def render_binop(self, c: Compiler, elem: BinOp) -> str:
        """Render all operands joined by the operator, in parentheses."""
        operands = [self.compile(c, arg) for arg in elem.args]
        separator = f" {elem.op} "
        return "(" + separator.join(operands) + ")"

    def render_tablepath(self, c: Compiler, elem: TablePath) -> str:
        """Render the table path as its quoted parts joined by dots."""
        return ".".join(self.quote(part) for part in elem.path)

    def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
        """Render the source table followed by the quoted alias name."""
        source = self.compile(c, elem.source_table)
        return f"{source} {self.quote(elem.name)}"

    def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
        """Render ``<table1> <op> <table2>``.

        Inside a SELECT the result is parenthesized and given a unique name;
        inside a JOIN it is only parenthesized.
        """
        inner_c: Compiler = attrs.evolve(parent_c, in_select=False)
        left = self.compile(inner_c, elem.table1)
        right = self.compile(inner_c, elem.table2)
        sql = f"{left} {elem.op} {right}"

        if parent_c.in_select:
            return f"({sql}) {inner_c.new_unique_name()}"
        if parent_c.in_join:
            return f"({sql})"
        return sql

    def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
        """Compile whatever the placeholder resolves to."""
        resolved = elem._get_resolved()
        return self.compile(c, resolved)

    def render_select(self, parent_c: Compiler, elem: Select) -> str:
        """Render a SELECT statement with its optional clauses.

        Clauses are emitted in this order: FROM (or the dialect's placeholder table),
        WHERE, GROUP BY, HAVING, ORDER BY; a limit is applied through `limit_select()`.
        Inside another SELECT the result is parenthesized and given a unique name;
        inside a JOIN it is only parenthesized.
        """
        c: Compiler = attrs.evolve(parent_c, in_select=True)

        def compile_all(exprs):
            return [self.compile(c, e) for e in exprs]

        columns = ", ".join(compile_all(elem.columns)) if elem.columns else "*"
        distinct = "DISTINCT " if elem.distinct else ""
        hints = self.optimizer_hints(elem.optimizer_hints) if elem.optimizer_hints else ""
        parts = [f"SELECT {hints}{distinct}{columns}"]

        if elem.table:
            parts.append(" FROM " + self.compile(c, elem.table))
        elif self.PLACEHOLDER_TABLE:
            parts.append(f" FROM {self.PLACEHOLDER_TABLE}")

        if elem.where_exprs:
            parts.append(" WHERE " + " AND ".join(compile_all(elem.where_exprs)))

        if elem.group_by_exprs:
            parts.append(" GROUP BY " + ", ".join(compile_all(elem.group_by_exprs)))

        if elem.having_exprs:
            assert elem.group_by_exprs
            parts.append(" HAVING " + " AND ".join(compile_all(elem.having_exprs)))

        if elem.order_by_exprs:
            parts.append(" ORDER BY " + ", ".join(compile_all(elem.order_by_exprs)))

        select = "".join(parts)

        if elem.limit_expr is not None:
            select = self.limit_select(
                select_query=select,
                offset=0,
                limit=elem.limit_expr,
                has_order_by=bool(elem.order_by_exprs),
            )

        if parent_c.in_select:
            return f"({select}) {c.new_unique_name()}"
        if parent_c.in_join:
            return f"({select})"
        return select

    def render_join(self, parent_c: Compiler, elem: Join) -> str:
        """Render a SELECT over joined tables.

        Source tables that are not aliased yet get a unique alias, and all of them are
        added to the table context used for compiling the tables, conditions and columns.
        Inside a SELECT the result is parenthesized and given a unique name;
        inside a JOIN it is only parenthesized.
        """
        tables = [
            src if isinstance(src, TableAlias) else TableAlias(src, name=parent_c.new_unique_name())
            for src in elem.source_tables
        ]
        c = parent_c.add_table_context(*tables, in_join=True, in_select=False)

        join_keyword = " JOIN " if elem.op is None else f" {elem.op} JOIN "
        source = join_keyword.join(self.compile(c, t) for t in tables)
        if elem.on_exprs:
            conditions = " AND ".join(self.compile(c, e) for e in elem.on_exprs)
            source = f"{source} ON {conditions}"

        if elem.columns is None:
            columns = "*"
        else:
            columns = ", ".join(self.compile(c, col) for col in elem.columns)
        select = f"SELECT {columns} FROM {source}"

        if parent_c.in_select:
            return f"({select}) {c.new_unique_name()}"
        if parent_c.in_join:
            return f"({select})"
        return select

    def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
        """Render a GROUP BY query over the element's table.

        The selected columns are the keys followed by the aggregated values, and the
        keys are referenced in GROUP BY by their 1-based position. When the table is a
        SELECT without columns and without grouping, the grouping is merged into that
        SELECT instead of wrapping it in another query. Inside a SELECT the result is
        parenthesized and given a unique name; inside a JOIN it is only parenthesized.

        Raises:
            CompileError: if no aggregated values were set (``.agg()`` was not called).
        """
        if elem.values is None:
            raise CompileError(".group_by() must be followed by a call to .agg()")

        keys = [str(i + 1) for i in range(len(elem.keys))]
        columns = (elem.keys or []) + (elem.values or [])
        if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None:
            return self.compile(
                c,
                attrs.evolve(
                    elem.table,
                    columns=columns,
                    group_by_exprs=[Code(k) for k in keys],
                    having_exprs=elem.having_exprs,
                ),
            )

        keys_str = ", ".join(keys)
        columns_str = ", ".join(self.compile(c, x) for x in columns)
        # Truthiness check (as in render_select): an empty list must not emit a dangling " HAVING ".
        having_str = ""
        if elem.having_exprs:
            having_str = " HAVING " + " AND ".join(self.compile(c, e) for e in elem.having_exprs)
        source = self.compile(attrs.evolve(c, in_select=True), elem.table)
        select = f"SELECT {columns_str} FROM {source} GROUP BY {keys_str}{having_str}"

        if c.in_select:
            select = f"({select}) {c.new_unique_name()}"
        elif c.in_join:
            select = f"({select})"
        return select

    def render_in(self, c: Compiler, elem: In) -> str:
        """Render ``(<expr> IN (<items>))``."""
        members = ", ".join(self.compile(c, item) for item in elem.list)
        target = self.compile(c, elem.expr)
        return f"({target} IN ({members}))"

    def render_cast(self, c: Compiler, elem: Cast) -> str:
        """Render ``cast(<expr> as <target type>)``."""
        value = self.compile(c, elem.expr)
        target = self.compile(c, elem.target_type)
        return f"cast({value} as {target})"

    def render_random(self, c: Compiler, elem: Random) -> str:
        """Render the dialect's random-number expression; neither argument is used."""
        random_sql = self.random()
        return random_sql

    def render_explain(self, c: Compiler, elem: Explain) -> str:
        """Compile the wrapped select and render the dialect's EXPLAIN for it."""
        query = self.compile(c, elem.select)
        return self.explain_as_text(query)

    def render_currenttimestamp(self, c: Compiler, elem: CurrentTimestamp) -> str:
        """Render the dialect's current-timestamp expression; neither argument is used."""
        now_sql = self.current_timestamp()
        return now_sql

    def render_createtable(self, c: Compiler, elem: CreateTable) -> str:
        """Render a CREATE TABLE statement.

        With a source table, ``CREATE TABLE ... AS <source>`` is rendered. Otherwise the
        columns come from the schema of the target path, and a PRIMARY KEY clause is added
        when primary keys are given and the dialect supports them.
        """
        ne = "IF NOT EXISTS " if elem.if_not_exists else ""
        if elem.source_table:
            return f"CREATE TABLE {ne}{self.compile(c, elem.path)} AS {self.compile(c, elem.source_table)}"

        schema = ", ".join(f"{self.quote(k)} {self.type_repr(v)}" for k, v in elem.path.schema.items())
        if elem.primary_keys and self.SUPPORTS_PRIMARY_KEY:
            # Quote the key names exactly like the column definitions above, so both refer
            # to the same identifiers on databases where quoting affects name resolution.
            pks = f", PRIMARY KEY ({', '.join(map(self.quote, elem.primary_keys))})"
        else:
            pks = ""
        return f"CREATE TABLE {ne}{self.compile(c, elem.path)}({schema}{pks})"

    def render_droptable(self, c: Compiler, elem: DropTable) -> str:
        """Render ``DROP TABLE [IF EXISTS] <path>``."""
        guard = "IF EXISTS " if elem.if_exists else ""
        path = self.compile(c, elem.path)
        return f"DROP TABLE {guard}{path}"

    def render_truncatetable(self, c: Compiler, elem: TruncateTable) -> str:
        """Render ``TRUNCATE TABLE <path>``."""
        path = self.compile(c, elem.path)
        return f"TRUNCATE TABLE {path}"

    def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str:
        """Render ``INSERT INTO <path>[(<columns>)] <source>``.

        A constant table is rendered as a VALUES list; any other source is compiled.
        """
        if isinstance(elem.expr, ConstantTable):
            source = self.constant_values(elem.expr.rows)
        else:
            source = self.compile(c, elem.expr)

        if elem.columns is None:
            column_list = ""
        else:
            column_list = "(" + ", ".join(self.quote(col) for col in elem.columns) + ")"

        return f"INSERT INTO {self.compile(c, elem.path)}{column_list} {source}"

    def limit_select(
        self,
        select_query: str,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        has_order_by: Optional[bool] = None,
    ) -> str:
        if offset:
            raise NotImplementedError("No support for OFFSET in query")

        return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT LIMIT {limit}"

    def concat(self, items: List[str]) -> str:
        "Provide SQL for concatenating a bunch of columns into a string"
        assert len(items) > 1
        joined_exprs = ", ".join(items)
        return f"concat({joined_exprs})"

    def to_comparable(self, value: str, coltype: ColType) -> str:
        """Ensure that the expression is comparable in ``IS DISTINCT FROM``.

        This default implementation returns `value` unchanged and ignores `coltype`.
        """
        return value

    def is_distinct_from(self, a: str, b: str) -> str:
        "Provide SQL for a comparison where NULL = NULL is true"
        return f"{a} is distinct from {b}"

    def timestamp_value(self, t: DbTime) -> str:
        """Provide SQL for the given timestamp value: its ISO format, in single quotes."""
        iso = t.isoformat()
        return f"'{iso}'"

    def random(self) -> str:
        """Provide SQL for generating a random number between 0..1."""
        return "random()"

    def current_timestamp(self) -> str:
        """Provide SQL for returning the current timestamp, aka now."""
        return "current_timestamp()"

    def current_database(self) -> str:
        """Provide SQL for returning the current default database."""
        return "current_database()"

    def current_schema(self) -> str:
        """Provide SQL for returning the current default schema."""
        return "current_schema()"

    def explain_as_text(self, query: str) -> str:
        "Provide SQL for explaining a query, returned as table(varchar)"
        return f"EXPLAIN {query}"

    def _constant_value(self, v):
        if v is None:
            return "NULL"
        elif isinstance(v, str):
            return f"'{v}'"
        elif isinstance(v, datetime):
            return self.timestamp_value(v)
        elif isinstance(v, UUID):  # probably unused anymore in favour of ArithUUID
            return f"'{v}'"
        elif isinstance(v, ArithUUID):
            return f"'{v.uuid}'"
        elif isinstance(v, decimal.Decimal):
            return str(v)
        elif isinstance(v, bytearray):
            return f"'{v.decode()}'"
        elif isinstance(v, Code):
            return v.code
        return repr(v)

    def constant_values(self, rows) -> str:
        values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
        return f"VALUES {values}"

    def type_repr(self, t) -> str:
        if isinstance(t, str):
            return t
        elif isinstance(t, TimestampTZ):
            return f"TIMESTAMP({min(t.precision, DEFAULT_DATETIME_PRECISION)})"
        return {
            int: "INT",
            str: "VARCHAR",
            bool: "BOOLEAN",
            float: "FLOAT",
            datetime: "TIMESTAMP",
        }[t]

    def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
        """Parse type info as returned by the database.

        The column's data type is looked up in `TYPE_CLASSES`; an unmapped data type
        yields an `UnknownColType`. `table_path` is not used by this implementation.

        Raises:
            TypeError: if the mapped class belongs to none of the handled type families.
        """
        # NOTE(review): DEFAULT_DATETIME_PRECISION and DEFAULT_NUMERIC_PRECISION are resolved as
        # module-level names (not as the class attribute DEFAULT_NUMERIC_PRECISION); they are not
        # defined in the visible part of this file - confirm.
        cls = self.TYPE_CLASSES.get(info.data_type)
        if cls is None:
            return UnknownColType(info.data_type)

        if issubclass(cls, TemporalType):
            precision = info.datetime_precision
            if precision is None:
                precision = DEFAULT_DATETIME_PRECISION
            return cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)

        if issubclass(cls, (Integer, Boolean)):
            return cls()

        if issubclass(cls, Decimal):
            # A missing scale is treated as 0. Needed for Oracle.
            scale = 0 if info.numeric_scale is None else info.numeric_scale
            return cls(precision=scale)

        if issubclass(cls, Float):
            binary_precision = info.numeric_precision
            if binary_precision is None:
                binary_precision = DEFAULT_NUMERIC_PRECISION
            return cls(precision=self._convert_db_precision_to_digits(binary_precision))

        if issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
            return cls()

        raise TypeError(f"Parsing {info.data_type} returned an unknown type {cls!r}.")

    def _convert_db_precision_to_digits(self, p: int) -> int:
        """Convert from binary precision, used by floats, to decimal precision."""
        # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
        return math.floor(math.log(2**p, 10))

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the dialect."""

    @property
    @abstractmethod
    def ROUNDS_ON_PREC_LOSS(self) -> bool:
        """True if db rounds real values when losing precision, False if it truncates.

        `parse_type()` passes this value as ``rounds`` to the temporal column types it creates.
        """

    @abstractmethod
    def quote(self, s: str):
        """Quote an SQL name (used here for column, table, alias and table-path part names)."""

    @abstractmethod
    def to_string(self, s: str) -> str:
        # TODO rewrite using cast_to(x, str)
        """Provide SQL for casting a column to string.

        `s` is an already-compiled SQL expression.
        """

    @abstractmethod
    def set_timezone_to_utc(self) -> str:
        """Provide SQL for setting the session timezone to UTC."""

    @abstractmethod
    def md5_as_int(self, s: str) -> str:
        """Provide SQL for computing md5 and returning an int.

        `s` is an already-compiled SQL expression; `render_checksum()` sums the results.
        """

    @abstractmethod
    def md5_as_hex(self, s: str) -> str:
        """Provide SQL for computing the md5 of the SQL expression `s`.

        NOTE(review): presumably returns the digest as a hex string, as the name suggests - confirm.
        """

    @abstractmethod
    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Create an SQL expression that converts 'value' to a normalized timestamp.

        The returned expression must accept any SQL datetime/timestamp, and return a string.

        Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF``

        Precision of dates should be rounded up/down according to coltype.rounds
        e.g. precision 3 and coltype.rounds:
            - 1969-12-31 23:59:59.999999 -> 1970-01-01 00:00:00.000000
            - 1970-01-01 00:00:00.000888 -> 1970-01-01 00:00:00.001000
            - 1970-01-01 00:00:00.123123 -> 1970-01-01 00:00:00.123000

        Make sure NULLs remain NULLs.
        """

    @abstractmethod
    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Create an SQL expression that converts 'value' to a normalized number.

        The returned expression must accept any SQL int/numeric/float, and return a string.

        Floats/Decimals are expected in the format
        "I.P"

        Where I is the integer part of the number (as many digits as necessary),
        and must be at least one digit (0).
        P is the fractional digits, the amount of which is specified with
        coltype.precision. Trailing zeroes may be necessary.
        If P is 0, the dot is omitted.

        Note: We use 'precision' differently than most databases. For decimals,
        it's the same as ``numeric_scale``, and for floats, which use binary precision,
        it can be calculated as ``log10(2**numeric_precision)``.
        """

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        """Create an SQL expression that converts 'value' to either '0' or '1'.

        This default implementation only casts `value` to a string.
        NOTE(review): whether that cast yields '0'/'1' depends on the database - confirm per dialect.
        """
        return self.to_string(value)

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Create an SQL expression that strips uuids of artifacts like whitespace.

        String UUIDs are trimmed; any other UUID type is cast to a string.
        """
        if not isinstance(coltype, String_UUID):
            return self.to_string(value)
        return f"TRIM({value})"

    def normalize_json(self, value: str, _coltype: JSON) -> str:
        """Create an SQL expression that converts 'value' to its minified json string representation.

        This default implementation only casts `value` to a string.
        """
        return self.to_string(value)

    def normalize_array(self, value: str, _coltype: Array) -> str:
        """Create an SQL expression that serializes an array into a JSON string.

        This default implementation only casts `value` to a string.
        """
        return self.to_string(value)

    def normalize_struct(self, value: str, _coltype: Struct) -> str:
        """Create an SQL expression that serializes a typed struct into a JSON string.

        This default implementation only casts `value` to a string.
        """
        return self.to_string(value)

    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
        """Create an SQL expression that converts 'value' to a normalized representation.

        The returned expression must accept any SQL value, and return a string.

        The default implementation dispatches to a method according to `coltype`;
        the first matching entry wins:

        ::

            TemporalType    -> normalize_timestamp()
            FractionalType  -> normalize_number()
            ColType_UUID    -> normalize_uuid()
            Boolean         -> normalize_boolean()
            JSON            -> normalize_json()
            Array           -> normalize_array()
            Struct          -> normalize_struct()
            *else*          -> to_string()

            (`Integer` falls in the *else* category)

        """
        # Ordered: the sequence of checks is part of the behavior.
        dispatch = (
            (TemporalType, self.normalize_timestamp),
            (FractionalType, self.normalize_number),
            (ColType_UUID, self.normalize_uuid),
            (Boolean, self.normalize_boolean),
            (JSON, self.normalize_json),
            (Array, self.normalize_array),
            (Struct, self.normalize_struct),
        )
        for type_family, normalizer in dispatch:
            if isinstance(coltype, type_family):
                return normalizer(value, coltype)
        return self.to_string(value)

    def optimizer_hints(self, hints: str) -> str:
        return f"/*+ {hints} */ "


# Type variable bound to BaseDialect, for annotations that accept any dialect subclass.
T = TypeVar("T", bound=BaseDialect)
# A single result row: a sequence of values (see QueryResult.rows).
Row = Sequence[Any]


@attrs.define(frozen=True)
class QueryResult:
    """Immutable holder for the rows of a query, plus optional column names.

    Iteration, ``len()`` and indexing are all forwarded to ``rows``.
    """

    rows: List[Row]
    columns: Optional[list] = None

    def __getitem__(self, i) -> Row:
        """Return ``rows[i]``."""
        fetched = self.rows
        return fetched[i]

    def __len__(self) -> int:
        """Return the number of rows."""
        fetched = self.rows
        return len(fetched)

    def __iter__(self) -> Iterator[Row]:
        """Return an iterator over the rows."""
        fetched = self.rows
        return iter(fetched)


@attrs.define(frozen=False, kw_only=True)
class Database(abc.ABC):
    """Base abstract class for databases.

    Used for providing connection code and implementation specific SQL utilities.

    Instantiated using :meth:`~data_diff.connect`
    """

    # Dialect class instantiated lazily by the `dialect` property.
    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = BaseDialect

    # When true, _refine_coltypes() tries to detect alphanumeric text columns.
    SUPPORTS_ALPHANUMS: ClassVar[bool] = True
    # When false, query_table_unique_columns() raises NotImplementedError.
    # (The misspelling "CONSTAINT" is part of the public name.)
    SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False
    # Presumably the keyword parameters accepted in the connection URI -- TODO confirm against the connect code.
    CONNECT_URI_KWPARAMS: ClassVar[List[str]] = []

    # Schema used by _normalize_table_path() when a path has a single component.
    default_schema: Optional[str] = None
    # Set by enable_interactive(); makes query() EXPLAIN each Select and ask for confirmation.
    _interactive: bool = False
    # Set to True by close().
    is_closed: bool = False
    # Cached dialect instance; created on first access of the `dialect` property.
    _dialect: BaseDialect = None

    def __enter__(self) -> Self:
        """Enter a ``with`` block; the database object itself is the context value."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Leave a ``with`` block by calling `close()`.

        Returns None, so an exception raised inside the block is not suppressed.
        """
        self.close()

    @property
    def name(self):
        """The name of the concrete class of this instance (used in log and error messages)."""
        cls = type(self)
        return cls.__name__

    def compile(self, sql_ast):
        """Compile 'sql_ast' with this database's dialect, using a fresh Compiler bound to this database."""
        dialect = self.dialect
        compiler = Compiler(self)
        return dialect.compile(compiler, sql_ast)

    def query(
        self,
        sql_ast: Union[Expr, Generator],
        res_type: Optional[type] = None,
        log_message: Optional[str] = None,
    ):
        """Query the given SQL code/AST, and attempt to convert the result to type 'res_type'

        If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
        The results of the queries are returned by the `yield` stmt (using the .send() mechanism).
        It's a cleaner approach than exposing cursors, but may not be enough in all cases.

        Accepted forms of 'sql_ast':

        * a generator - wrapped in a ThreadLocalInterpreter and handed to `_query()`;
        * a list - every item but the last is queried for its effect only, and the
          result of the last item (converted to 'res_type') is returned;
        * a string - sent as-is;
        * anything else - compiled with `compile()`; if 'res_type' is None it
          defaults to `sql_ast.type`. If compilation yields SKIP, SKIP is returned
          and nothing is executed.

        Supported values of 'res_type':

        * ``list`` - the result as a list;
        * ``int`` - the single value of the single row (None is passed through);
        * ``datetime`` - the single value; a string is parsed from its first 23 characters;
        * ``tuple`` - the single row;
        * ``List[int]`` / ``List[str]`` - the single value of every row;
        * ``List[tuple]`` / ``List[Tuple]`` - every row as a tuple;
        * ``List[dict]`` - every row as a dict keyed by column name;
        * anything else - the raw result of `_query()`.

        'log_message' is included in the debug log line of each executed statement.

        In interactive mode, a Select is first EXPLAINed and the user is asked
        whether to continue; a negative answer exits the process.

        Raises ValueError when an ``int`` result has no rows or an empty row, or
        when a parametrized list type is not one of the supported ones.
        """

        compiler = Compiler(self)
        if isinstance(sql_ast, Generator):
            sql_code = ThreadLocalInterpreter(compiler, sql_ast)
        elif isinstance(sql_ast, list):
            # Forward the log message too, so that the statements of a batch are
            # logged with the same context as a single statement would be.
            for i in sql_ast[:-1]:
                self.query(i, log_message=log_message)
            return self.query(sql_ast[-1], res_type, log_message)
        else:
            if isinstance(sql_ast, str):
                sql_code = sql_ast
            else:
                if res_type is None:
                    res_type = sql_ast.type
                sql_code = self.compile(sql_ast)
                if sql_code is SKIP:
                    return SKIP

            if log_message:
                logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
            else:
                logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

        if self._interactive and isinstance(sql_ast, Select):
            explained_sql = self.compile(Explain(sql_ast))
            explain = self._query(explained_sql)
            for row in explain:
                # Most returned a 1-tuple. Presto returns a string
                if isinstance(row, tuple):
                    (row,) = row
                logger.debug("EXPLAIN: %s", row)
            answer = input("Continue? [y/n] ")
            if answer.lower() not in ["y", "yes"]:
                sys.exit(1)

        res = self._query(sql_code)
        if res_type is list:
            return list(res)
        elif res_type is int:
            if not res:
                raise ValueError("Query returned 0 rows, expected 1")
            row = _one(res)
            if not row:
                raise ValueError("Row is empty, expected 1 column")
            res = _one(row)
            if res is None:  # May happen due to sum() of 0 items
                return None
            return int(res)
        elif res_type is datetime:
            res = _one(_one(res))
            if isinstance(res, str):
                res = datetime.fromisoformat(res[:23])  # TODO use a better parsing method
            return res
        elif res_type is tuple:
            assert len(res) == 1, (sql_code, res)
            return res[0]
        elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
            # Parametrized list types, e.g. List[int]; dispatch on the single type argument.
            if res_type.__args__ in ((int,), (str,)):
                return [_one(row) for row in res]
            elif res_type.__args__ in [(Tuple,), (tuple,)]:
                return [tuple(row) for row in res]
            elif res_type.__args__ == (dict,):
                return [dict(safezip(res.columns, row)) for row in res]
            else:
                raise ValueError(res_type)
        return res

    def enable_interactive(self) -> None:
        """Turn on interactive mode.

        In this mode `query()` runs an EXPLAIN for every Select it receives and
        asks for confirmation on stdin before executing it.
        """
        self._interactive = True

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema.

        The selected columns are, in order: column_name, data_type,
        datetime_precision, numeric_precision, numeric_scale.
        """
        schema, name = self._normalize_table_path(path)

        clauses = (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale",
            "FROM information_schema.columns",
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'",
        )
        return " ".join(clauses)

    def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
        """Query the schema of the table in 'path', and return {column_name: RawColumnInfo}.

        Each row of the schema query is read positionally as
        (column_name, data_type, datetime_precision, numeric_precision, numeric_scale[, collation_name]);
        the collation is None when a row has no sixth item.

        Raises RuntimeError when the schema query returns no rows.

        Note: This method exists instead of select_table_schema(), just because not all databases support
              accessing the schema using a SQL query.
        """
        schema_sql = self.select_table_schema(path)
        rows = self.query(schema_sql, list, log_message=path)

        if not rows:
            raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

        columns = {}
        for r in rows:
            collation = r[5] if len(r) > 5 else None
            columns[r[0]] = RawColumnInfo(
                column_name=r[0],
                data_type=r[1],
                datetime_precision=r[2],
                numeric_precision=r[3],
                numeric_scale=r[4],
                collation_name=collation,
            )

        # Every row must describe a distinct column name.
        assert len(columns) == len(rows)
        return columns

    def select_table_unique_columns(self, path: DbPath) -> str:
        """Provide SQL that selects column names from information_schema.key_column_usage for the table."""
        schema, name = self._normalize_table_path(path)

        clauses = (
            "SELECT column_name",
            "FROM information_schema.key_column_usage",
            f"WHERE table_name = '{name}' AND table_schema = '{schema}'",
        )
        return " ".join(clauses)

    def query_table_unique_columns(self, path: DbPath) -> List[str]:
        """Return the names of the unique columns of the table in 'path', as a list.

        Raises NotImplementedError when the class does not declare support for
        'unique' constraints.
        """
        if self.SUPPORTS_UNIQUE_CONSTAINT:
            unique_sql = self.select_table_unique_columns(path)
            names = self.query(unique_sql, List[str], log_message=path)
            return list(names)

        raise NotImplementedError("This database doesn't support 'unique' constraints")

    def _process_table_schema(
        self,
        path: DbPath,
        raw_schema: Dict[str, RawColumnInfo],
        filter_columns: Sequence[str] = None,
        where: str = None,
    ):
        """Turn the result of query_table_schema() into a {column_name: ColType} dict.

        It is a separate step so that only the requested columns get processed,
        because processing a column may:
        * throw errors and warnings
        * query the database to sample values

        'filter_columns' (matched case-insensitively) selects the columns to keep;
        None keeps them all. 'where' is passed on to the sampling step.
        """
        if filter_columns is not None:
            wanted = {column.lower() for column in filter_columns}
            selected = [info for name, info in raw_schema.items() if name.lower() in wanted]
        else:
            selected = list(raw_schema.values())

        col_dict = {}
        for info in selected:
            col_dict[info.column_name] = self.dialect.parse_type(path, info)

        # Sampling may replace some of the parsed types in place.
        self._refine_coltypes(path, col_dict, where)

        return col_dict

    def _refine_coltypes(
        self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size: int = 64
    ) -> Dict[str, ColType]:
        """Refine the types in the column dict, by querying the database for a sample of their values

        Only columns whose current type is `Text` are sampled. A column is
        re-typed as `String_UUID` when every sampled value is a UUID, or
        (if the database supports alphanumerics) as `String_VaryingAlphanum`
        when every sampled value passes `String_Alphanum.test_value()`.

        'where' restricts the rows to be sampled.
        'sample_size' is the maximum number of rows to sample.

        'col_dict' is updated in place, and is also returned.
        """

        text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
        if not text_columns:
            # Nothing to refine, so no query is issued.
            return col_dict

        # Select each text column through the dialect's UUID normalization.
        fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]

        samples_by_row = self.query(
            table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
            list,
            log_message=table_path,
        )
        # Transpose rows into per-column samples; with no rows, every column gets an empty sample.
        samples_by_col = list(zip(*samples_by_row)) if samples_by_row else [[]] * len(text_columns)
        for col_name, samples in safezip(text_columns, samples_by_col):
            # Falsy samples (e.g. None or '') never count as UUIDs.
            uuid_samples = [s for s in samples if s and is_uuid(s)]

            if uuid_samples:
                if len(uuid_samples) != len(samples):
                    # Mixed content: warn, then fall through to the alphanumeric check below.
                    logger.warning(
                        f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
                    )
                else:
                    assert col_name in col_dict
                    # Record whether all the sampled UUIDs share one letter case.
                    col_dict[col_name] = String_UUID(
                        lowercase=all(s == s.lower() for s in uuid_samples),
                        uppercase=all(s == s.upper() for s in uuid_samples),
                    )
                    # The column is settled as a UUID; skip the alphanumeric check.
                    continue

            if self.SUPPORTS_ALPHANUMS:  # Anything but MySQL (so far)
                alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)]
                if alphanum_samples:
                    if len(alphanum_samples) != len(samples):
                        logger.debug(
                            f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key."
                        )
                    else:
                        assert col_name in col_dict
                        # Keep the collation of the type that is being replaced.
                        col_dict[col_name] = String_VaryingAlphanum(collation=col_dict[col_name].collation)

        return col_dict

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Return 'path' as a (schema, table) pair.

        A one-item path is completed with `default_schema`; a two-item path is
        returned unchanged. Any other length raises ValueError.
        """
        n_parts = len(path)
        if n_parts == 2:
            return path
        if n_parts == 1:
            (table_name,) = path
            return self.default_schema, table_name

        raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")

    def _query_cursor(self, c, sql_code: str) -> QueryResult:
        """Execute 'sql_code' on cursor 'c', and fetch its rows when it is a reading statement.

        Only statements whose text starts (case-insensitively) with "select",
        "explain" or "show" have their rows fetched and returned as a QueryResult
        with the column names from the cursor description. For any other statement
        None is returned. Exceptions raised by the cursor propagate unchanged.
        """
        assert isinstance(sql_code, str), sql_code

        c.execute(sql_code)

        returns_rows = sql_code.lower().startswith(("select", "explain", "show"))
        if not returns_rows:
            return None

        column_names = [col[0] for col in c.description]
        fetched = c.fetchall()
        return QueryResult(fetched, column_names)

    def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
        """Run 'sql_code' on a new cursor of 'conn' and return the result of `apply_query()`.

        `apply_query()` is given a callback that executes a single SQL string
        on that cursor via `_query_cursor()`.
        """
        cursor = conn.cursor()
        run_on_cursor = partial(self._query_cursor, cursor)
        return apply_query(run_on_cursor, sql_code)

    def close(self) -> None:
        """Close connection(s) to the database instance. Querying will stop functioning.

        This base implementation only marks the instance as closed by setting `is_closed`.
        """
        self.is_closed = True

    @property
    def dialect(self) -> BaseDialect:
        """The dialect of the database. Used internally by Database, and also available publicly.

        The instance is created from `DIALECT_CLASS` on first access and cached.
        """
        cached = self._dialect
        if cached:
            return cached

        self._dialect = self.DIALECT_CLASS()
        return self._dialect

    @property
    @abstractmethod
    def CONNECT_URI_HELP(self) -> str:
        """Example URI to show the user in help and error messages.

        Abstract: it has no implementation in this base class.
        """

    @property
    @abstractmethod
    def CONNECT_URI_PARAMS(self) -> List[str]:
        """List of parameters given in the path of the URI.

        Abstract: it has no implementation in this base class.
        """

    @abstractmethod
    def _query(self, sql_code: str) -> list:
        """Send query to database and return result.

        Abstract hook used by `query()`, which passes either an SQL string or,
        when it was given a generator, a ThreadLocalInterpreter.
        """

    @property
    @abstractmethod
    def is_autocommit(self) -> bool:
        """Return whether the database autocommits changes. When false, COMMIT statements are skipped.

        NOTE(review): "when false ... skipped" reads inverted (an autocommitting
        database would not need COMMIT) -- confirm against the code that compiles Commit.
        """


@attrs.define(frozen=False)
class ThreadedDatabase(Database):
    """Database that runs every query on the worker threads of a thread pool.

    Each worker thread opens its own connection when it starts, and keeps it in
    thread-local storage.

    Used for database connectors that do not support sharing their connection between different threads.
    """

    thread_count: int = 1

    _init_error: Optional[Exception] = None
    _queue: Optional[ThreadPoolExecutor] = None
    thread_local: threading.local = attrs.field(factory=threading.local)

    def __attrs_post_init__(self) -> None:
        """Create the thread pool; each worker runs `set_conn()` as its initializer."""
        pool = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
        self._queue = pool
        logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")

    def set_conn(self):
        """Open a connection for the current thread and keep it in thread-local storage.

        A failure to connect is not raised here; it is saved, and re-raised by
        the next query that runs on a worker.
        """
        assert not hasattr(self.thread_local, "conn")
        try:
            connection = self.create_connection()
        except Exception as e:
            self._init_error = e
        else:
            self.thread_local.conn = connection

    def _query(self, sql_code: Union[str, ThreadLocalInterpreter]) -> QueryResult:
        """Submit the query to the pool, and block until its result is available."""
        future = self._queue.submit(self._query_in_worker, sql_code)
        return future.result()

    def _query_in_worker(self, sql_code: Union[str, ThreadLocalInterpreter]):
        """This method runs in a worker thread"""
        init_error = self._init_error
        if init_error:
            raise init_error

        connection = self.thread_local.conn
        return self._query_conn(connection, sql_code)

    @abstractmethod
    def create_connection(self):
        """Return a connection instance, that supports the .cursor() method."""

    def close(self):
        """Mark the database closed, shut the pool down, and close the current thread's connection, if any.

        NOTE(review): `thread_local` is looked up in the thread that calls close(),
        so only a connection stored by that very thread is closed here.
        """
        super().close()
        self._queue.shutdown()

        local = self.thread_local
        if hasattr(local, "conn"):
            local.conn.close()

    @property
    def is_autocommit(self) -> bool:
        """Always False for this class."""
        return False


# Number of hexadecimal digits in a checksum.
CHECKSUM_HEXDIGITS = 12  # Must be 12 or lower, otherwise SUM() overflows
# Number of hexadecimal digits in an MD5 digest (128 bits).
MD5_HEXDIGITS = 32

# Each hex digit holds 4 bits, hence the shift by 2 (a multiplication by 4).
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2
# Bit mask with the lowest _CHECKSUM_BITSIZE bits set.
CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1

# A bigint is typically 8 bytes wide.
# If the checksum is shorter than that, most databases pad it with zeros:
# 0xFF → 0x00000000000000FF.
# Because of that, the numeric representation is always positive,
# which limits the number of checksums that can be added together before overflowing.
# We can fix that by adding a negative offset of half the max value,
# so that the values are distributed from -0.5*max to +0.5*max.
# Negative numbers can then compensate for the positive ones, which allows
# adding more checksums together without overflowing.
CHECKSUM_OFFSET = CHECKSUM_MASK // 2

# Presumably the fallback precisions for when the schema reports none -- TODO confirm at the usage sites.
DEFAULT_DATETIME_PRECISION = 6
DEFAULT_NUMERIC_PRECISION = 24

# Length of a timestamp string up to and including the dot before the fractional seconds.
TIMESTAMP_PRECISION_POS = 20  # len("2022-06-03 12:24:35.") == 20
